{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6eb94b72",
   "metadata": {},
   "source": [
    "#### Set environment variables in [.env](.env) for LLM API calling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "388020c6",
   "metadata": {},
   "source": [
    "### Import Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11efa138",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"../../\")\n",
    "import promptwizard\n",
    "from promptwizard.glue.promptopt.instantiate import GluePromptOpt\n",
    "from promptwizard.glue.promptopt.techniques.common_logic import DatasetSpecificProcessing\n",
    "from promptwizard.glue.common.utils.file import save_jsonlist\n",
    "from typing import Any\n",
    "from tqdm import tqdm\n",
    "from re import compile, findall\n",
    "import os\n",
    "from datasets import load_dataset\n",
    "\n",
    "from dotenv import load_dotenv\n",
    "load_dotenv(override = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "beb14821",
   "metadata": {},
   "source": [
    "### Create a dataset specific class and define the required functions "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f325d33",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GSM8k(DatasetSpecificProcessing):\n",
    "    \"\"\"GSM8K-specific dataset conversion and answer extraction.\n",
    "\n",
    "    Converts question/answer records into the JSONL layout keyed by the\n",
    "    ``DatasetSpecificProcessing`` literals and extracts the numeric final\n",
    "    answer from a model completion.\n",
    "    \"\"\"\n",
    "\n",
    "    # Sentinel returned when no numeric answer can be extracted. Declared on the\n",
    "    # class so extract_final_answer() does not depend on dataset_to_jsonl()\n",
    "    # having been called first on the same instance.\n",
    "    INVALID_ANS = \"[invalid]\"\n",
    "\n",
    "    # Matches the \"#### <number>\" marker in a reference answer; compiled once\n",
    "    # instead of on every sample.\n",
    "    _REFERENCE_ANSWER_RE = compile(r\"#### (\\-?[0-9\\.\\,]+)\")\n",
    "\n",
    "    def dataset_to_jsonl(self, dataset_jsonl: str, **kwargs: Any) -> None:\n",
    "        \"\"\"Convert ``kwargs[\"dataset\"]`` into example dicts and save them as JSONL.\n",
    "\n",
    "        Args:\n",
    "            dataset_jsonl: Destination path handed to ``save_jsonlist`` with \"w\".\n",
    "            **kwargs: Must contain ``dataset``, an iterable of mappings that\n",
    "                provide ``question`` and ``answer`` entries.\n",
    "\n",
    "        Raises:\n",
    "            KeyError: If ``dataset`` is missing from ``kwargs`` or a sample lacks\n",
    "                ``question`` / ``answer``.\n",
    "        \"\"\"\n",
    "        def extract_answer_from_output(completion):\n",
    "            # The final value follows the \"#### \" marker in the reference answer.\n",
    "            match = self._REFERENCE_ANSWER_RE.search(completion)\n",
    "            if match is None:\n",
    "                return self.INVALID_ANS\n",
    "            # Remove thousands separators so \"1,234\" is stored as \"1234\".\n",
    "            return match.group(1).strip().replace(\",\", \"\")\n",
    "\n",
    "        # Iterate the dataset directly (the index was never used) so tqdm can\n",
    "        # show a total when the dataset has a length.\n",
    "        examples_set = [\n",
    "            {\n",
    "                DatasetSpecificProcessing.QUESTION_LITERAL: sample['question'],\n",
    "                DatasetSpecificProcessing.ANSWER_WITH_REASON_LITERAL: sample['answer'],\n",
    "                DatasetSpecificProcessing.FINAL_ANSWER_LITERAL: extract_answer_from_output(sample[\"answer\"]),\n",
    "            }\n",
    "            for sample in tqdm(kwargs[\"dataset\"], desc=\"Evaluating samples\")\n",
    "        ]\n",
    "\n",
    "        save_jsonlist(dataset_jsonl, examples_set, \"w\")\n",
    "\n",
    "    def extract_final_answer(self, answer: str):\n",
    "        \"\"\"Return the numeric answer found in a model completion, as a string.\n",
    "\n",
    "        The text is lower-cased and split on ``self.ANSWER_START`` (not defined in\n",
    "        this class; presumably inherited from the base class - confirm). When\n",
    "        the marker is present, the first number after its last occurrence is\n",
    "        returned; otherwise the last number anywhere in the text is returned.\n",
    "\n",
    "        Args:\n",
    "            answer: Raw model output; may be empty or ``None``.\n",
    "\n",
    "        Returns:\n",
    "            The matched number with commas and any trailing period removed, or\n",
    "            ``INVALID_ANS`` when the input is empty or contains no number.\n",
    "        \"\"\"\n",
    "        if not answer:\n",
    "            return self.INVALID_ANS\n",
    "\n",
    "        model_pred = answer.lower()\n",
    "        preds = model_pred.split(self.ANSWER_START.lower())\n",
    "        answer_flag = len(preds) > 1\n",
    "\n",
    "        # Only the text after the last marker (or the whole text) is searched.\n",
    "        numbers = findall(r'-?\\d+\\.?\\d*', preds[-1].replace(\",\", \"\"))\n",
    "        if not numbers:\n",
    "            return self.INVALID_ANS\n",
    "\n",
    "        # With a marker the answer is expected right after it, so take the first\n",
    "        # number; without one, fall back to the last number in the text.\n",
    "        pred = numbers[0] if answer_flag else numbers[-1]\n",
    "\n",
    "        # The pattern can capture a sentence-ending period (e.g. \"42.\"); drop it.\n",
    "        if pred.endswith(\".\"):\n",
    "            pred = pred[:-1]\n",
    "        return pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f384eb57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single processor instance: used below to write the JSONL splits and then\n",
    "# passed to GluePromptOpt.\n",
    "gsm8k_processor = GSM8k()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "11d2de75",
   "metadata": {},
   "source": [
    "### Load and save the dataset "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "976681bd-4f43-4dbc-947e-cdb94d4824f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the output directory; exist_ok keeps the cell safe to re-run.\n",
    "os.makedirs(\"data\", exist_ok=True)\n",
    "\n",
    "dataset = load_dataset(\"openai/gsm8k\", \"main\")\n",
    "\n",
    "# We sample only 100 train examples and use 25 out of them for training randomly\n",
    "MAX_TRAIN_SAMPLES = 100\n",
    "\n",
    "for dataset_type in ['train', 'test']:\n",
    "    data_list = []\n",
    "    for data in dataset[dataset_type]:\n",
    "        # Check the cap before appending so the train split holds exactly\n",
    "        # MAX_TRAIN_SAMPLES records; the test split is written in full.\n",
    "        if dataset_type == 'train' and len(data_list) >= MAX_TRAIN_SAMPLES:\n",
    "            break\n",
    "        data_list.append({\"question\": data['question'], \"answer\": data['answer']})\n",
    "    gsm8k_processor.dataset_to_jsonl(\"data/\" + dataset_type + '.jsonl', dataset=data_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac30e74f",
   "metadata": {},
   "source": [
    "### Set paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f43482f1-3e10-4cf7-8ea6-ff42c04067a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_file_name = os.path.join(\"data\", \"train.jsonl\")\n",
    "test_file_name = os.path.join(\"data\", \"test.jsonl\")\n",
    "path_to_config = \"configs\"\n",
    "promptopt_config_path = os.path.join(path_to_config, \"promptopt_config.yaml\")\n",
    "setup_config_path = os.path.join(path_to_config, \"setup_config.yaml\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3392594d",
   "metadata": {},
   "source": [
    "### Create an object for calling prompt optimization and inference functionalities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8af4246f-db32-4b37-a73a-f9e2e5125d09",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the optimizer from both config files, the training split and the\n",
    "# GSM8k processor instance.\n",
    "gp = GluePromptOpt(\n",
    "    promptopt_config_path,\n",
    "    setup_config_path,\n",
    "    train_file_name,\n",
    "    gsm8k_processor,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1784648c",
   "metadata": {},
   "source": [
    "### Call prompt optimization function\n",
    "1. ```use_examples``` can be used when there are training samples and a mixture of real and synthetic in-context examples is required in the final prompt. When set to ```False```, all the in-context examples will be real.\n",
    "2. ```generate_synthetic_examples``` can be used when there are no training samples and we want to generate synthetic examples \n",
    "3. ```run_without_train_examples``` can be used when there are no training samples and in-context examples are not required in the final prompt "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "573c6151-2c03-45d9-9904-1724a1e20f1b",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Run prompt optimization; returns the best prompt and the expert profile.\n",
    "best_prompt, expert_profile = gp.get_best_prompt(\n",
    "    use_examples=True,\n",
    "    run_without_train_examples=False,\n",
    "    generate_synthetic_examples=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ee1aa99",
   "metadata": {},
   "source": [
    "### Save the optimized prompt and expert profile"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34a716af-0d77-4c7d-b1c2-6438d66096ce",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# Create the output directory in-process instead of shelling out to `mkdir`:\n",
    "# portable across platforms, raises on failure, and safe to re-run.\n",
    "os.makedirs(\"results\", exist_ok=True)\n",
    "\n",
    "# NOTE: pickle files should only be loaded back from trusted locations.\n",
    "with open(\"results/best_prompt.pkl\", 'wb') as f:\n",
    "    pickle.dump(best_prompt, f)\n",
    "with open(\"results/expert_profile.pkl\", 'wb') as f:\n",
    "    pickle.dump(expert_profile, f)\n",
    "\n",
    "print(f\"Best prompt: {best_prompt} \\nExpert profile: {expert_profile}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aac42eed",
   "metadata": {},
   "source": [
    "### Evaluate the optimized prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c49b5711-82dd-4d18-8cd4-ee447cf8d74c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Attach the optimized expert profile and prompt to the optimizer object.\n",
    "gp.EXPERT_PROFILE = expert_profile\n",
    "gp.BEST_PROMPT = best_prompt\n",
    "\n",
    "# Score the optimized prompt on the test split and report the result.\n",
    "accuracy = gp.evaluate(test_file_name)\n",
    "print(f\"Final Accuracy: {accuracy}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "general",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
